#include <gtest/gtest.h>

#include <cstddef>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include "framework/graph/operator.h"
#include "framework/graph/op_desc.h"
#include "framework/graph/compatible/operator_reg.h"
#include "framework/graph/graph.h"
#include "framework/graph/model.h"
#include "framework/graph/compute_graph.h"
#include "framework/graph/attr_value.h"
#include "framework/graph/utils/graph_utils.h"
#include "framework/graph/compatible/all_ops.h"
#include "framework/graph/core/infershape/graph_verify_util.h"
#include "framework/graph/core/infershape/graph_infershape_util.h"

using namespace std;
using namespace ge;
using namespace op;

class ge_test_infer_shape : public testing::Test {
protected:
    void SetUp()
    {
    }

    void TearDown()
    {
    }
};

/* test utility functions */

// LOOP_VEC(v): loops the index `i` over [0, (v).size()).
// The argument is parenthesized so that any expression (not only a plain
// name) can be passed. The index is std::size_t to match the container's
// size type and avoid a signed/narrow comparison.
#define LOOP_VEC(v) for (std::size_t i = 0; i < (v).size(); ++i)

// SET_CONST_SHAPE(op, shape): stores a Tensor built from TensorDesc(shape) as
// the "value" attribute of `op` (via set_attr_value).
// Uses std::make_shared instead of a naked `new` (single allocation,
// exception-safe). The object argument is parenthesized so that expressions
// such as `*ptr` expand correctly.
#define SET_CONST_SHAPE(op, shape) \
        (op).set_attr_value(std::make_shared<Tensor>(TensorDesc(shape)))

// SetInputShape(op, name, shape): helper used by TEST_OPERATOR to apply an
// input shape before inference. Only declared here; the definition lives in
// another translation unit (presumably another test source of this suite --
// confirm). The commented-out body after the declaration looks like an
// earlier local definition: it fetches the named input's TensorDesc, replaces
// its shape, and writes it back with UpdateInputDesc.
// NOTE(review): `op` is taken by value, so for the update to reach the
// caller's operator, Operator copies presumably share their underlying
// state -- confirm against the Operator implementation.
extern void SetInputShape(Operator op, string name, vector<int64_t> shape);/* {
    TensorDesc td = op.GetInputDesc(name);
    td.SetShape(Shape(shape));
    op.UpdateInputDesc(name, td);
}*/

// CheckOutputShape(op, name, shape): helper used by TEST_OPERATOR to verify
// an inferred output shape. Only declared here; the definition lives in
// another translation unit (presumably another test source of this suite --
// confirm). The commented-out copy below appears to be a reference version:
// it EXPECT_EQs the rank and each dimension of the named output against
// `shape`, and prints both shapes to stdout.
extern void CheckOutputShape(Operator op, string name, vector<int64_t> shape); 
/*void CheckOutputShape(Operator op, string name, vector<int64_t> shape) {
    Shape s = op.GetOutputDesc(name).GetShape();

    EXPECT_EQ(s.GetDims().size(), shape.size());
    LOOP_VEC(shape) EXPECT_EQ(s.GetDim(i), shape[i]);

    cout << "CHECKED [";
    LOOP_VEC(shape) cout << s.GetDim(i) << ", ";
    cout << "] == [";
    LOOP_VEC(shape) cout << shape[i] << ", ";
    cout << "]\n";
}*/

/*
 * TEST_OPERATOR(op_, input_shapes, output_shapes)
 * Copies `op_`, applies every (input name -> shape) entry of `input_shapes`
 * with SetInputShape, runs InferShapeAndType(), then checks every
 * (output name -> shape) entry of `output_shapes` with CheckOutputShape.
 * The return value of InferShapeAndType() is ignored; only the resulting
 * output shapes are checked.
 *
 * Hygiene notes:
 *  - The local copy has a distinctive name, so a caller may pass a variable
 *    named `op` (the previous `auto op = op_;` expanded to `auto op = op;`).
 *  - Map entries are bound by const reference instead of being copied.
 *  - do { } while (0) makes the macro a single statement (safe in if/else).
 */
#define TEST_OPERATOR(op_, input_shapes, output_shapes) \
    do { \
        auto test_operator_copy_ = (op_); \
        for (const auto &input_pair_ : (input_shapes)) \
            SetInputShape(test_operator_copy_, input_pair_.first, input_pair_.second); \
        test_operator_copy_.InferShapeAndType(); \
        for (const auto &output_pair_ : (output_shapes)) \
            CheckOutputShape(test_operator_copy_, output_pair_.first, output_pair_.second); \
    } while (0)

/* unit tests */

/*TEST_F(ge_test_infer_shape, ge_test_infer_shape_resnet_test)
{
    auto data = op::Const("data");
    SET_CONST_SHAPE(data, Shape({1, 3, 224, 224}));
    auto weights = op::Const("weights");
    SET_CONST_SHAPE(weights, Shape({64, 3, 7, 7}));
    auto bias = op::Const("bias");
    SET_CONST_SHAPE(bias, Shape({64}));

    auto conv1 = op::Convolution("conv1")
                    .set_input_x(data)
                    .set_input_w(weights)
                    .set_input_b(bias)
                    .set_attr_stride({1,1})
                    .set_attr_pad({3,3,3,3});


    auto scale = op::Const("scale");
    SET_CONST_SHAPE(scale, Shape({64}));
    auto b = op::Const("b");
    SET_CONST_SHAPE(b, Shape({64}));
    auto mean = op::Const("mean");
    SET_CONST_SHAPE(mean, Shape({64}));
    auto variance = op::Const("variance");
    SET_CONST_SHAPE(variance, Shape({64}));
    auto fusedbatchnorm = op::FusedBatchNorm("fusedbatchnorm1")
                            .set_input_x(conv1.get_out_y())
                            .set_input_scale(scale)
                            .set_input_b(b)
                            .set_input_mean(mean)
                            .set_input_variance(variance);

    auto relu = op::ReLU("relu")
                    .set_input_x(fusedbatchnorm.get_out_y());

    auto weights2 = op::Const("weights2");
    SET_CONST_SHAPE(weights2, Shape({64, 64, 7, 7}));

    auto conv2 = op::Conv2D("conv2")
                    .set_input_x(relu)
                    .set_input_w(weights2)
                    .set_attr_stride({1,1})
                    .set_attr_pad({3,3,3,3});

    auto eltwise = op::Eltwise("eltwise")
                    .set_input_x1(relu)
                    .set_input_x2(conv2);

    auto inputs = vector<Operator>{data};
    auto graph = GraphUtils::CreateGraphFromOperator("test_graph", inputs);
    EXPECT_EQ(ge::GraphVerifyUtil::Verify(graph), GRAPH_SUCCESS);
    EXPECT_EQ(GraphInfershapeUtil::InferShape(graph), GRAPH_SUCCESS);
    CheckOutputShape(data, "y", {1, 3, 224, 224});
    CheckOutputShape(weights, "y", {64, 3, 7, 7});
    CheckOutputShape(bias, "y", {64});
    CheckOutputShape(conv1, "y", {1, 64, 224, 224});

    CheckOutputShape(scale, "y", {64});
    CheckOutputShape(b, "y", {64});
    CheckOutputShape(mean, "y", {64});
    CheckOutputShape(variance, "y", {64});
    CheckOutputShape(fusedbatchnorm, "y", {1, 64, 224, 224});

    CheckOutputShape(relu, "y", {1, 64, 224, 224});

    CheckOutputShape(weights2, "y", {64, 64, 7, 7});
    CheckOutputShape(conv2, "y", {1, 64, 224, 224});
    CheckOutputShape(eltwise, "y", {1, 64, 224, 224});
    //GraphUtils::SaveToDotFile(graph, "test.dot");
    //GraphUtils::PrintComputeGraph(graph);
}*/


TEST_F(ge_test_infer_shape, softmax)
{
    // Softmax (axis = 1) is expected to produce a "y" output whose shape
    // equals the "x" input shape.
    using ShapeMap = unordered_map<string, vector<int64_t>>;
    const ShapeMap softmax_inputs{{"x", {2, 4, 2, 1}}};
    const ShapeMap softmax_expected{{"y", {2, 4, 2, 1}}};

    auto softmax_op = op::Softmax("test_op").set_attr_axis(1);

    TEST_OPERATOR(softmax_op, softmax_inputs, softmax_expected);
}

TEST_F(ge_test_infer_shape, elmtwise)
{
    // Element-wise style operators: each case expects a "y" output whose
    // shape equals the single 2x3x4x5 input it was given.
    using ShapeMap = unordered_map<string, vector<int64_t>>;
    const ShapeMap x_input{{"x", {2, 3, 4, 5}}};
    const ShapeMap x1_input{{"x1", {2, 3, 4, 5}}};
    const ShapeMap y_output{{"y", {2, 3, 4, 5}}};

    // Disabled cases (reason not recorded here), each previously run with
    // x_input -> y_output:
    //   op::Sqrt("sqrt1"), op::Square("square1"),
    //   op::SSDNormalize("ssd_normalize1"), op::Tanh("tanh"),
    //   op::ReLU("relu1"), op::ReLU6("relu61")
    auto eltwise = op::Eltwise("eltwise1");
    auto activation = op::Activation("activation1");

    TEST_OPERATOR(eltwise, x1_input, y_output);
    TEST_OPERATOR(activation, x_input, y_output);
}

TEST_F(ge_test_infer_shape, const_op)
{
    // A Const op receives no input shapes; its "y" output is expected to carry
    // the shape of the tensor assigned through set_attr_value.
    using ShapeMap = unordered_map<string, vector<int64_t>>;
    const ShapeMap no_inputs;
    const ShapeMap const_expected{{"y", {64, 3, 7, 7}}};

    auto constant = op::Const("Const1");
    TensorDesc value_desc(Shape({64, 3, 7, 7}));
    Tensor value_tensor(value_desc);
    constant.set_attr_value(std::make_shared<Tensor>(value_tensor));

    TEST_OPERATOR(constant, no_inputs, const_expected);
}

TEST_F(ge_test_infer_shape, convolution)
{
    using ShapeMap = unordered_map<string, vector<int64_t>>;

    // Shared input: 32x3x224x224 activations convolved with 64 3x7x7 kernels.
    const ShapeMap conv_inputs{
        {"x", {32, 3, 224, 224}},
        {"w", {64, 3, 7, 7}},
    };

    // Expected "y" shapes, one per configuration below.
    const ShapeMap out_218{{"y", {32, 64, 218, 218}}};  // no pad, stride 1: 224 - 7 + 1
    const ShapeMap out_224{{"y", {32, 64, 224, 224}}};  // pad 3 on every side
    const ShapeMap out_112{{"y", {32, 64, 112, 112}}};  // pad 3, stride 2
    const ShapeMap out_56{{"y", {32, 64, 56, 56}}};     // pad_mode 6, stride 4
    const ShapeMap out_73{{"y", {32, 64, 73, 73}}};     // pad_mode 5, stride 3

    auto conv_plain = op::Convolution("conv1");
    auto conv_padded = op::Convolution("conv2")
            .set_attr_pad({3, 3, 3, 3});
    auto conv_padded_stride2 = op::Convolution("conv3")
            .set_attr_pad({3, 3, 3, 3})
            .set_attr_stride({2, 2});
    auto conv_mode6_stride4 = op::Convolution("conv4")
            .set_attr_pad_mode(6)
            .set_attr_stride({4, 4});
    auto conv_mode5_stride3 = op::Convolution("conv5")
            .set_attr_pad_mode(5)
            .set_attr_stride({3, 3});

    TEST_OPERATOR(conv_plain, conv_inputs, out_218);
    TEST_OPERATOR(conv_padded, conv_inputs, out_224);
    TEST_OPERATOR(conv_padded_stride2, conv_inputs, out_112);
    TEST_OPERATOR(conv_mode6_stride4, conv_inputs, out_56);
    TEST_OPERATOR(conv_mode5_stride3, conv_inputs, out_73);
}

TEST_F(ge_test_infer_shape, fc)
{
    // FullConnection with x = 32x2048x1x1 and w = 12x2048x1x1 is expected
    // to produce y = 32x12x1x1.
    using ShapeMap = unordered_map<string, vector<int64_t>>;
    const ShapeMap fc_inputs{
        {"x", {32, 2048, 1, 1}},
        {"w", {12, 2048, 1, 1}},
    };
    const ShapeMap fc_expected{{"y", {32, 12, 1, 1}}};

    auto full_connection = op::FullConnection("fc");

    TEST_OPERATOR(full_connection, fc_inputs, fc_expected);
}

TEST_F(ge_test_infer_shape, pooling)
{
    using ShapeMap = unordered_map<string, vector<int64_t>>;

    // Every case pools the same 32x3x224x224 input and sets a 7x7 window.
    const ShapeMap pool_input{{"x", {32, 3, 224, 224}}};

    // Expected "y" shapes; several configurations share the 109x109 result.
    const ShapeMap out_109{{"y", {32, 3, 109, 109}}};
    const ShapeMap out_112{{"y", {32, 3, 112, 112}}};
    const ShapeMap out_1{{"y", {32, 3, 1, 1}}};
    const ShapeMap out_110{{"y", {32, 3, 110, 110}}};
    const ShapeMap out_116{{"y", {32, 3, 116, 116}}};

    auto pool_mode5_padded = op::Pooling("Pooling1")
            .set_attr_window({7, 7})
            .set_attr_pad_mode(5)
            .set_attr_pad({3, 3, 3, 3})
            .set_attr_stride({2, 2});
    auto pool_mode5 = op::Pooling("Pooling2")
            .set_attr_window({7, 7})
            .set_attr_pad_mode(5)
            .set_attr_stride({2, 2});
    auto pool_mode6 = op::Pooling("Pooling3")
            .set_attr_window({7, 7})
            .set_attr_pad_mode(6)
            .set_attr_stride({2, 2});
    // global_pooling = true: expected to collapse the spatial dims to 1x1.
    auto pool_global = op::Pooling("Pooling4")
            .set_attr_global_pooling(true)
            .set_attr_window({7, 7})
            .set_attr_pad_mode(5)
            .set_attr_stride({2, 2});
    auto pool_mode4_ceil = op::Pooling("Pooling5")
            .set_attr_global_pooling(false)
            .set_attr_window({7, 7})
            .set_attr_pad_mode(4)
            .set_attr_pad({0,0,0,0})
            .set_attr_stride({2, 2})
            .set_attr_data_mode(0)
            .set_attr_ceil_mode(1);
    auto pool_mode0_pad7_ceil = op::Pooling("Pooling6")
            .set_attr_window({7, 7})
            .set_attr_pad_mode(0)
            .set_attr_stride({2, 2})
            .set_attr_pad({7,7,7,7})
            .set_attr_data_mode(0)
            .set_attr_ceil_mode(1);
    auto pool_mode0_floor = op::Pooling("Pooling7")
            .set_attr_window({7, 7})
            .set_attr_pad_mode(0)
            .set_attr_pad({0,0,0,0})
            .set_attr_stride({2, 2})
            .set_attr_data_mode(0)
            .set_attr_ceil_mode(0);
    auto pool_mode0_data1 = op::Pooling("Pooling8")
            .set_attr_window({7, 7})
            .set_attr_pad_mode(0)
            .set_attr_pad({0,0,0,0})
            .set_attr_stride({2, 2})
            .set_attr_data_mode(1);

    TEST_OPERATOR(pool_mode5_padded, pool_input, out_109);
    TEST_OPERATOR(pool_mode5, pool_input, out_109);
    TEST_OPERATOR(pool_mode6, pool_input, out_112);
    TEST_OPERATOR(pool_global, pool_input, out_1);
    TEST_OPERATOR(pool_mode4_ceil, pool_input, out_110);
    TEST_OPERATOR(pool_mode0_pad7_ceil, pool_input, out_116);
    TEST_OPERATOR(pool_mode0_floor, pool_input, out_109);
    TEST_OPERATOR(pool_mode0_data1, pool_input, out_109);
}

TEST_F(ge_test_infer_shape, convolution1)
{
    // Invalid Convolution configurations. Every case expects an empty "y"
    // shape (presumably what remains when shape inference rejects the op).
    using ShapeMap = unordered_map<string, vector<int64_t>>;

    const ShapeMap valid_inputs{
        {"x", {32, 3, 224, 224}},
        {"w", {64, 3, 7, 7}},
    };
    // Weight channel count (2) differs from the input channel count (3).
    const ShapeMap mismatched_inputs{
        {"x", {32, 3, 224, 224}},
        {"w", {4, 2, 7, 7}},
    };
    const ShapeMap empty_output{{"y", {}}};

    // group = 2 while the input has 3 channels.
    auto conv_group2 = op::Convolution("conv6")
            .set_attr_group(2);
    // Stride with three elements instead of two.
    auto conv_bad_stride = op::Convolution("conv7")
            .set_attr_stride({4, 4, 1});
    // Pad with five elements instead of four.
    auto conv_bad_pad = op::Convolution("conv8")
            .set_attr_pad_mode(0)
            .set_attr_pad({3, 3, 3, 3, 1});
    auto conv_default = op::Convolution("conv5");

    TEST_OPERATOR(conv_group2, valid_inputs, empty_output);
    TEST_OPERATOR(conv_bad_stride, valid_inputs, empty_output);
    TEST_OPERATOR(conv_default, mismatched_inputs, empty_output);
    TEST_OPERATOR(conv_bad_pad, valid_inputs, empty_output);
    // TODO: a pad_mode = 10 case ("conv9") was previously built here but never
    // run (its TEST_OPERATOR call was commented out); the unused operator was
    // removed. Re-add it together with its check when that case is supported.
}
TEST_F(ge_test_infer_shape, codesafe)
{
    // Each case below uses a zero stride in at least one spatial dimension and
    // expects an empty "y" shape.
    using ShapeMap = unordered_map<string, vector<int64_t>>;

    const ShapeMap conv_inputs{
        {"x", {32, 3, 224, 224}},
        {"w", {64,  3, 7, 7}},
    };
    const ShapeMap empty_output{{"y", {}}};

    auto conv_zero_h = op::Convolution("conv1")
         .set_attr_stride({0, 1});
    auto conv_zero_w = op::Convolution("conv2")
         .set_attr_stride({1, 0});
    auto conv_zero_both = op::Convolution("conv3")
         .set_attr_stride({0, 0});

    TEST_OPERATOR(conv_zero_h, conv_inputs, empty_output);
    TEST_OPERATOR(conv_zero_w, conv_inputs, empty_output);
    TEST_OPERATOR(conv_zero_both, conv_inputs, empty_output);

    const ShapeMap pool_input{{"x", {32, 3, 224, 224}}};

    auto pool_zero_h = op::Pooling("Pooling1")
        .set_attr_window({7, 7})
        .set_attr_pad_mode(5)
        .set_attr_pad({3, 3, 3, 3})
        .set_attr_stride({0, 2});
    auto pool_zero_w = op::Pooling("Pooling2")
        .set_attr_window({7, 7})
        .set_attr_pad_mode(5)
        .set_attr_pad({3, 3, 3, 3})
        .set_attr_stride({2, 0});
    auto pool_zero_both = op::Pooling("Pooling3")
        .set_attr_window({7, 7})
        .set_attr_pad_mode(5)
        .set_attr_pad({3, 3, 3, 3})
        .set_attr_stride({0, 0});

    TEST_OPERATOR(pool_zero_h, pool_input, empty_output);
    TEST_OPERATOR(pool_zero_w, pool_input, empty_output);
    TEST_OPERATOR(pool_zero_both, pool_input, empty_output);
}

